# Copyright 2017 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Helpers related to multiprocessing."""

import atexit
import logging
import multiprocessing
import multiprocessing.dummy
import os
import sys
import threading
import traceback


# When the SUPERSIZE_DISABLE_ASYNC environment variable is '1', the helpers in
# this module run work synchronously in the calling process/thread instead of
# using process pools or worker threads.
DISABLE_ASYNC = os.environ.get('SUPERSIZE_DISABLE_ASYNC') == '1'
if DISABLE_ASYNC:
  logging.debug('Running in synchronous mode.')

# Process pools created by _MakeProcessPool that have not been cleaned up yet.
# Stays None until the first pool is created.
_all_pools = None
# Checked by _TerminatePools so that a forked child does not try to tear down
# pools; set to True by _FuncWrapper.
_is_child_process = False
# Once True, _CheckForException no longer logs subprocess tracebacks (it still
# exits). Set after the first logged exception and by _TerminatePools.
_silence_exceptions = False


class _ImmediateResult(object):
  """Result object for a value that was computed synchronously.

  Mirrors the get()/wait()/ready()/successful() surface of an asynchronous
  result so callers need not care whether DISABLE_ASYNC is in effect.
  """
  def __init__(self, value):
    self._value = value

  def get(self):
    """Returns the already-computed value."""
    return self._value

  def wait(self):
    """Does nothing: there is never anything to wait for."""

  def ready(self):
    """Always True, since the value exists at construction time."""
    return True

  def successful(self):
    """Always True; a failing computation would have raised before this."""
    return self.ready()


class _ExceptionWrapper(object):
  """Carries a formatted traceback from a worker process back to the parent.

  A plain picklable object is returned instead of raising, so the parent can
  recognize it and report the child's failure.
  """
  def __init__(self, msg):
    # The traceback text (as produced by traceback.format_exc()).
    self.msg = msg


class _FuncWrapper(object):
  """Runs on the fork()'ed side to catch exceptions and spread *args.

  Instances are constructed in the parent and handed to pool workers, so only
  __call__ is guaranteed to execute inside the child process.
  """
  def __init__(self, func):
    self._func = func

  def __call__(self, args, _=None):
    """Calls the wrapped function as func(*args).

    Args:
      args: Tuple of positional arguments for the wrapped function.
      _: Unused extra positional argument (kept for interface compatibility).

    Returns:
      The wrapped function's return value, or an _ExceptionWrapper holding the
      formatted traceback if it raised.
    """
    # Mark this process as a child so _TerminatePools does not try to tear down
    # pools from here. This must happen in __call__ (which runs in the worker),
    # not __init__ (which runs in the parent after the workers have already
    # forked): setting it in the parent would disable _TerminatePools there.
    global _is_child_process
    _is_child_process = True
    try:
      return self._func(*args)
    except:  # pylint: disable=bare-except
      # multiprocessing is supposed to catch and return exceptions automatically
      # but it doesn't seem to work properly :(.
      logging.warning('CAUGHT EXCEPTION')
      return _ExceptionWrapper(traceback.format_exc())


class _WrappedResult(object):
  """Wraps an async result so host-side logic runs once the child finishes.

  * Removes the owning pool (if any) from _all_pools once the result is ready.
  * Passes the value to _CheckForException, which exits the process if it is
    an exception captured by _FuncWrapper.
  * Optionally post-processes a successful value with |decode_func|.
  """
  def __init__(self, result, pool=None, decode_func=None):
    self._result = result
    self._pool = pool
    self._decode_func = decode_func

  def get(self):
    """Blocks until done, then returns the (optionally decoded) value."""
    self.wait()
    raw = self._result.get()
    _CheckForException(raw)
    should_decode = self._decode_func and self._result.successful()
    return self._decode_func(raw) if should_decode else raw

  def wait(self):
    """Blocks until done; unregisters the owning pool the first time only."""
    self._result.wait()
    if not self._pool:
      return
    _all_pools.remove(self._pool)
    self._pool = None

  def ready(self):
    """Delegates to the wrapped result."""
    return self._result.ready()

  def successful(self):
    """Delegates to the wrapped result."""
    return self._result.successful()


def _TerminatePools():
  """Terminates every process pool still registered in _all_pools.

  The multiprocessing docs say this should not be needed, but in practice it
  is when a child raises or Ctrl-C is hit. _MakeProcessPool registers this
  with atexit.
  """
  global _silence_exceptions
  _silence_exceptions = True
  # atexit handlers registered before fork() are inherited by children, but
  # children do not own pools, so bail out early there.
  if _is_child_process:
    return

  def _BestEffortTerminate(target_pool):
    try:
      target_pool.terminate()
    except:  # pylint: disable=bare-except
      pass  # Already shutting down; nothing useful to do with the error.

  # An inline terminate() call can block forever, so run each one on a daemon
  # thread that will not keep the interpreter alive.
  for index, target_pool in enumerate(_all_pools):
    killer = threading.Thread(name='Pool-Terminate-{}'.format(index),
                              target=_BestEffortTerminate,
                              args=(target_pool,))
    killer.daemon = True
    killer.start()


def _CheckForException(value):
  """Exits with status 1 if |value| is an exception marshalled by _FuncWrapper.

  Only the first such exception is logged. Once _silence_exceptions is set
  (here, or by _TerminatePools), later ones exit without logging.
  """
  global _silence_exceptions
  if not isinstance(value, _ExceptionWrapper):
    return
  if not _silence_exceptions:
    _silence_exceptions = True
    logging.error('Subprocess raised an exception:\n%s', value.msg)
  sys.exit(1)


def _MakeProcessPool(*args):
  """Creates multiprocessing.Pool(*args) and records it in _all_pools.

  The first call also initializes _all_pools and registers _TerminatePools to
  run at interpreter exit.
  """
  global _all_pools
  new_pool = multiprocessing.Pool(*args)
  is_first_pool = _all_pools is None
  if is_first_pool:
    _all_pools = []
    atexit.register(_TerminatePools)
  _all_pools.append(new_pool)
  return new_pool


def ForkAndCall(func, args, decode_func=None):
  """Runs |func|(*|args|) in a fork'ed process.

  Args:
    func: Callable to run.
    args: Tuple of positional arguments for |func|.
    decode_func: Optional callable that .get() applies to a successful value.

  Returns:
    A Result object (call .get() to get the return value). When DISABLE_ASYNC
    is set, |func| has already run synchronously by the time this returns.
  """
  if not DISABLE_ASYNC:
    pool = _MakeProcessPool(1)
    result = pool.apply_async(_FuncWrapper(func), (args,))
    # Single-use pool: no further work will be submitted to it.
    pool.close()
  else:
    pool = None
    result = _ImmediateResult(func(*args))
  return _WrappedResult(result, pool=pool, decode_func=decode_func)


def BulkForkAndCall(func, arg_tuples):
  """Calls |func| in a fork'ed process for each set of args within |arg_tuples|.

  Yields the return values as they come in, in arbitrary order. When
  DISABLE_ASYNC is set, |func| runs in-process and results are yielded in
  input order.

  Args:
    func: Callable invoked as func(*args) for each args tuple.
    arg_tuples: Sized sequence of argument tuples. May be empty, in which case
      nothing is yielded and no pool is created.
  """
  # multiprocessing.Pool() rejects a pool size of 0, so without this check an
  # empty input would raise ValueError instead of yielding nothing.
  if not arg_tuples:
    return
  if DISABLE_ASYNC:
    for args in arg_tuples:
      yield func(*args)
    return
  pool_size = min(len(arg_tuples), multiprocessing.cpu_count())
  pool = _MakeProcessPool(pool_size)
  wrapped_func = _FuncWrapper(func)
  for result in pool.imap_unordered(wrapped_func, arg_tuples):
    _CheckForException(result)
    yield result
  pool.close()
  pool.join()
  _all_pools.remove(pool)


def CallOnThread(func, *args, **kwargs):
  """Runs |func|(*args, **kwargs) on a worker thread.

  Returns:
    A promise whose .get() returns |func|'s return value. When DISABLE_ASYNC
    is set, |func| instead runs immediately on the calling thread.
  """
  if not DISABLE_ASYNC:
    thread_pool = multiprocessing.dummy.Pool(1)
    async_result = thread_pool.apply_async(func, args=args, kwds=kwargs)
    # No more tasks; the worker thread exits once |func| completes.
    thread_pool.close()
    return async_result
  return _ImmediateResult(func(*args, **kwargs))


def EncodeDictOfLists(d, key_transform=None):
  """Serializes a dict where values are lists of strings.

  Keys are joined with '\\x01'. Each value list is joined with '\\x02', and
  those strings are joined with '\\x01'. Keys and strings must therefore not
  contain either separator. DecodeDictOfLists reverses this.

  Note: {} and empty value lists do not round-trip exactly. They decode as
  {'': ['']} and [''] respectively.

  Args:
    d: Dict mapping string keys to lists of strings.
    key_transform: Optional function applied to each key before joining.

  Returns:
    A (encoded_keys, encoded_values) tuple of strings.
  """
  keys = iter(d)
  if key_transform:
    keys = (key_transform(k) for k in keys)
  keys = '\x01'.join(keys)
  # values() works on both Python 2 and 3 (itervalues() is Python 2-only) and
  # iterates in the same order as iter(d) for an unmodified dict.
  values = '\x01'.join('\x02'.join(x) for x in d.values())
  return keys, values


def DecodeDictOfLists(encoded_keys, encoded_values, key_transform=None):
  """Rebuilds a dict of string lists from EncodeDictOfLists-style output.

  Args:
    encoded_keys: '\\x01'-separated keys.
    encoded_values: '\\x01'-separated groups, each '\\x02'-separated.
    key_transform: Optional function applied to each decoded key.

  Returns:
    A dict mapping each (transformed) key to its list of strings. Surplus
    value groups are ignored; too few raise IndexError.
  """
  key_list = encoded_keys.split('\x01')
  if key_transform:
    key_list = (key_transform(k) for k in key_list)
  groups = encoded_values.split('\x01')
  return {key: groups[pos].split('\x02')
          for pos, key in enumerate(key_list)}
